
import os
import warnings
# These environment variables are set before torch/transformers are imported
# below. With CUDA_VISIBLE_DEVICES="2" only physical GPU 2 is visible to the
# process, which is why it is addressed as "cuda:0" further down.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Silence every Python warning for the whole run.
warnings.filterwarnings("ignore")

import re
import json
import torch
import heapq
import string
import numpy as np
from tqdm import tqdm
# The regex constructors used below (CAtom, CConcat, CDisj, CStar) come from
# these FAdo star imports.
from FAdo.fa import *
from FAdo.reex import *
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import torch
import pickle
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

# Device passed to load_model_and_tokenizer (the single GPU exposed above).
device = "cuda:0"

# Checkpoint directory for the ILM model; the model itself is loaded from
# `model_dir`, which is set further down.
MODEL_DIR = "models/sto_ilm"

def load_model_and_tokenizer(model_dir: str, device: str = None):
    """Load a GPT-2 LM checkpoint from ``model_dir`` together with a matching tokenizer.

    The model is built from ``config.json``. Its weights are read from
    ``pytorch_model.bin`` non-strictly: missing and unexpected keys are printed,
    not raised. The tokenizer is the stock "gpt2" tokenizer. When
    ``additional_ids_to_tokens.pkl`` exists, the tokens it lists are added to
    the tokenizer and the embedding matrix is resized to the new tokenizer
    length.

    Args:
        model_dir: directory containing the checkpoint files.
        device: torch device string; ``None`` picks "cuda" when available,
            otherwise "cpu".

    Returns:
        tuple: ``(model, tokenizer, device)``, with the model on ``device`` and
        in eval mode.

    Raises:
        ValueError: if ``additional_ids_to_tokens.pkl`` holds neither a dict
            nor a list.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    config = GPT2Config.from_json_file(os.path.join(model_dir, "config.json"))
    model = GPT2LMHeadModel(config)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files
    # (newer torch versions support weights_only=True for this).
    state_dict = torch.load(os.path.join(model_dir, "pytorch_model.bin"), map_location="cpu")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    if missing:
        print(f"⚠️  Missing keys: {missing}")
    if unexpected:
        print(f"⚠️  Unexpected keys: {unexpected}")

    model.to(device)
    model.eval()

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    pkl_path = os.path.join(model_dir, "additional_ids_to_tokens.pkl")
    if not os.path.exists(pkl_path):
        print("No additional_ids_to_tokens.pkl file.")
        return model, tokenizer, device

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(pkl_path, "rb") as f:
        additional_tokens = pickle.load(f)

    # Ids the tokens are expected to receive, when the pickle provides them.
    expected_ids = None
    if isinstance(additional_tokens, dict):
        # The file name suggests an {id: token} mapping (assumed — confirm).
        # add_tokens appends new tokens in list order, so the tokens are added
        # sorted by their integer keys; dict insertion order need not follow
        # the ids.
        keys = list(additional_tokens)
        if all(isinstance(k, int) for k in keys):
            keys.sort()
            expected_ids = keys
        new_tokens = [additional_tokens[k] for k in keys]
    elif isinstance(additional_tokens, list):
        new_tokens = additional_tokens
    else:
        raise ValueError("additional_ids_to_tokens.pkl format not recognized")

    if new_tokens:
        added = tokenizer.add_tokens(new_tokens)
        model.resize_token_embeddings(len(tokenizer))
        print(f"Added {added} token extra.")
        # Print each token with its id, and report ids that differ from the pickle's.
        for pos, token in enumerate(new_tokens):
            token_id = tokenizer.convert_tokens_to_ids(token)
            print(f"Token: {token} -> ID: {token_id}")
            if expected_ids is not None and token_id != expected_ids[pos]:
                print(f"⚠️  Token {token} got ID {token_id}, expected {expected_ids[pos]}")

    return model, tokenizer, device


# Checkpoint directory (config.json, pytorch_model.bin and optionally
# additional_ids_to_tokens.pkl). When left empty, fall back to MODEL_DIR
# instead of silently resolving the files against the current directory.
model_dir = "" # insert model_dir
if not model_dir:
    model_dir = MODEL_DIR
model, tokenizer, device = load_model_and_tokenizer(model_dir, device)

def compute_next_state(current_state, word, dfa):
    """Run ``word`` through the DFA, one character at a time, from ``current_state``.

    A character that is not in ``dfa['alphabet']`` is routed through column 0,
    i.e. it follows the same transition as the first alphabet symbol.

    Returns:
        tuple: ``(final_state, dfa["distances"][final_state])``.
    """
    symbols = dfa['alphabet']
    table = dfa['transition_matrix']

    state = current_state
    for symbol in word:
        column = symbols.index(symbol) if symbol in symbols else 0
        state = table[state][column]

    return state, dfa["distances"][state]

def ramp_push_up(alpha_min, di, T, t, gamma):
    """Interpolate from ``alpha_min`` toward 1 as the distance ``di`` fills the remaining budget.

    The push is ``min(1, (di / (T - t)) ** gamma)``, so the result is
    ``alpha_min`` when ``di`` is 0 and 1 once ``di >= T - t`` (for positive
    ``gamma``). Raises ZeroDivisionError when ``t == T``.
    """
    steps_left = T - t
    push = min(1, (di / steps_left) ** gamma)
    return alpha_min + (1 - alpha_min) * push
    
def custom_beam_search(model, tokenizer, prompt, dfa, num_beams=3, max_length=50, alpha_min=0.5, gamma=1, eps=0.01, device="cuda"):
    """Beam search that steers a causal LM toward strings accepted by a DFA.

    Each beam is a tuple
    ``(seq, llm_score, automata_score, dist, state_history, token_history)``:

    - ``seq`` is the ``(1, L)`` tensor of token ids, with the prompt included.
    - The two scores are accumulated negative log-probabilities.
    - ``dist`` is ``dfa["distances"]`` of the beam's current DFA state.
    - The two histories record the visited DFA states and the generated token ids.

    At each of ``max_length`` steps, every unfinished beam proposes:

    - The model's top-``num_beams`` tokens. A token is dropped if its decoded
      text contains "[", "\\" or a newline, if it is a substring of the previous
      token's text, if it cannot reach acceptance within the remaining steps,
      or if it leads to the DFA's initial state.
    - A "state changer": the first token of the state's shortest accepting
      string (``dfa["min_strings"]``), when it strictly reduces ``dist``.
    - EOS, kept only if it leads to an accepting state, and always offered once
      the beam is in an accepting state.

    State changers are scored with a boosted probability
    ``alpha * max_prob + (1 - alpha + alpha * eps) * p``, where ``alpha`` grows
    from ``alpha_min`` via ramp_push_up. After each step the ``num_beams``
    candidates with the lowest ``automata_score`` survive. The answer is the
    last step's candidate with the lowest ``llm_score``.

    Returns:
        torch.Tensor: the ``(1, L)`` id tensor of the chosen sequence, prompt
        included. With ``max_length <= 0`` this is the encoded prompt.

    Raises:
        RuntimeError: if every candidate was pruned by the DFA constraints.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device=device)
    initial_state = dfa["initial_state"]
    accepting_states = dfa["accepting_states"]
    eos_id = tokenizer.eos_token_id

    beams = [(input_ids, 0.0, 0.0, dfa["distances"][initial_state], [initial_state], [])]
    # Returned unchanged when no decoding step runs (max_length <= 0).
    candidates = beams

    for t in range(max_length):
        if not beams:
            # Everything was pruned at the previous step; nothing can recover.
            break
        candidates = []
        remaining = max_length - t

        for seq, llm_score, automata_score, dist, state_history, token_history in beams:
            state = state_history[-1]

            # Beams that already emitted EOS in an accepting state are carried over as-is.
            if seq[0, -1] == eos_id and state in accepting_states:
                candidates.append((seq, llm_score, automata_score, dist, state_history, token_history))
                continue

            if t > 0:
                alpha = ramp_push_up(alpha_min, dist, max_length, t, gamma)
            else:
                alpha = alpha_min

            with torch.no_grad():
                next_token_logits = model(seq).logits[0, -1, :]
                _, top_indices = torch.topk(next_token_logits, num_beams)

                prev_word = tokenizer.decode(token_history[-1]) if token_history else None

                # 1) Model proposals that keep acceptance reachable within the budget.
                effective_top_indices = []
                for idx in top_indices.tolist():
                    word = tokenizer.decode(idx)
                    if "[" in word or "\\" in word or "\n" in word:
                        continue
                    # Skip tokens whose text repeats part of the previous token.
                    if prev_word is not None and word in prev_word:
                        continue
                    new_state, new_dist = compute_next_state(state, word, dfa)
                    if new_dist <= remaining and new_state != initial_state:
                        effective_top_indices.append(idx)

                # 2) State changer: first token of the shortest accepting string.
                # min_string is None for states that cannot reach acceptance, so it
                # is checked before any str method is called on it.
                state_changers = []
                min_string = dfa["min_strings"][state]
                if min_string and not min_string.startswith("@"):
                    min_string_ids = tokenizer.encode(min_string)
                    if min_string_ids:
                        state_changers = [min_string_ids[0]]

                for changer in state_changers:
                    word = tokenizer.decode(changer)
                    if prev_word is not None and word in prev_word:
                        continue
                    new_state, new_dist = compute_next_state(state, word, dfa)
                    if (new_dist <= remaining and changer not in effective_top_indices
                            and new_state != initial_state and new_dist < dist):
                        effective_top_indices.append(changer)

                # 3) EOS: drop it unless it lands in an accepting state; offer it
                # whenever the beam itself is already accepting.
                if eos_id in effective_top_indices:
                    eos_state, _ = compute_next_state(state, tokenizer.eos_token, dfa)
                    if eos_state not in accepting_states:
                        effective_top_indices.remove(eos_id)
                elif state in accepting_states:
                    effective_top_indices.append(eos_id)

                llm_probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
                max_prob = torch.max(llm_probs)

                for idx in effective_top_indices:
                    token_prob = llm_probs[idx].item()
                    current_llm_score = llm_score - np.log(token_prob)

                    if idx in state_changers:
                        boosted = alpha * max_prob + (1 - alpha + alpha * eps) * llm_probs[idx]
                        # .item() keeps every automata score a plain Python float.
                        current_automata_score = automata_score - torch.log(boosted).item()
                    else:
                        current_automata_score = automata_score - np.log(token_prob)

                    new_state, new_dist = compute_next_state(state, tokenizer.decode(idx), dfa)
                    new_seq = torch.cat([seq, torch.tensor([[idx]], device=seq.device)], dim=1)
                    candidates.append((new_seq, current_llm_score, current_automata_score, new_dist,
                                       state_history + [new_state], token_history + [idx]))

        # Keep the num_beams candidates with the lowest automata score.
        beams = sorted(candidates, key=lambda c: c[2])[:num_beams]

    if not candidates:
        raise RuntimeError("custom_beam_search: no candidate satisfies the DFA constraints")

    # Final choice among the last step's candidates: lowest LLM score.
    return min(candidates, key=lambda c: c[1])[0]

def extract_concepts(text):
    """Split ``text`` into literal chunks and ``<|...|>`` marker tokens, in order.

    Each marker runs from a ``<|`` to the nearest following ``|>`` and may span
    newlines. Empty chunks are dropped, so ``"a<|x|>b"`` gives
    ``["a", "<|x|>", "b"]`` and ``""`` gives ``[]``.
    """
    # The capturing group makes re.split keep the markers in the result.
    pieces = re.split(r'(<\|.*?\|>)', text, flags=re.DOTALL)
    return [piece for piece in pieces if piece]

# n-gram infill pattern:  [ ,.]* [a-zA-Z']+ [ ,.]* [a-zA-Z']*
# NOTE(review): this admits at most two runs of letters (e.g. "hello world",
# but not "a b c"); confirm that this is the intended n-gram size.
# Letters: a single-character alternation over [a-zA-Z'].
symbols_ngram = list(string.ascii_lowercase + string.ascii_uppercase + "'")
char_group = None
for s in symbols_ngram:
    atom = CAtom(s)
    char_group = atom if char_group is None else CDisj(char_group, atom)
# Separators: space, comma, period.
punct_ngram = list(" " + "," + ".")
punct_group = None
for s in punct_ngram:
    atom = CAtom(s)
    punct_group = atom if punct_group is None else CDisj(punct_group, atom)
ngram = CConcat(CStar(punct_group), 
                CConcat(CConcat(CConcat(char_group, CStar(char_group)), 
                CStar(punct_group)), CStar(char_group)))

# Sentence infill pattern:  [ ,]* [a-zA-Z' ,]* [a-zA-Z'] [ ,]* [.!?]
# (digits are NOT included). This rebinds the module-level char_group and
# punct_group built for the n-gram pattern.
symbols_sentence = list(string.ascii_lowercase + string.ascii_uppercase + "'")
char_group = None
for s in symbols_sentence:
    atom = CAtom(s)
    char_group = atom if char_group is None else CDisj(char_group, atom)

# Separators inside a sentence: space and comma.
punct_sentence = list(" " + ",")
punct_group = None
for s in punct_sentence:
    atom = CAtom(s)
    punct_group = atom if punct_group is None else CDisj(punct_group, atom)

# Body: any mix of letters/separators that ends on a letter.
all_chars = CDisj(char_group, punct_group)
content = CConcat(CStar(all_chars), char_group)
# Mandatory sentence terminator.
punct = CDisj(CDisj(CAtom("."), CAtom("!")), CAtom("?"))

sentence = CConcat(
    CConcat(
        CConcat(
            CStar(punct_group),
            content
        ),
        CStar(punct_group)
    ),
    punct
)

# Word infill pattern:  [a-zA-Z']+ [.,]*   (digits are NOT included).
# This rebinds the module-level char_group again.
symbols_word = list(string.ascii_lowercase + string.ascii_uppercase + "'")
char_group = None
for s in symbols_word:
    atom = CAtom(s)
    char_group = atom if char_group is None else CDisj(char_group, atom)
# Optional trailing punctuation after the word.
punct_end_group = CDisj(CAtom("."), CAtom(","))

word = CConcat(
    CConcat(char_group, CStar(char_group)),
    CStar(punct_end_group)
)

def literal(s: str):
    """Return a FAdo regex that matches exactly the string ``s``.

    Characters become CAtom nodes chained left to right with CConcat, so
    ``"abc"`` becomes ``CConcat(CConcat(a, b), c)``. An empty string yields
    ``None``.
    """
    pieces = [CAtom(ch) for ch in s]
    if not pieces:
        return None
    result = pieces[0]
    for piece in pieces[1:]:
        result = CConcat(result, piece)
    return result

def build_regex_from_line(line: str, eos):
    """Build a FAdo regex for ``line`` with its infill markers expanded, followed by ``eos``.

    ``line`` is split with extract_concepts. The markers ``<|infill_ngram|>``,
    ``<|infill_word|>`` and ``<|infill_sentence|>`` are replaced by the
    module-level ``ngram``, ``word`` and ``sentence`` patterns. Every other
    non-empty piece, including any other ``<|...|>`` marker, must appear
    verbatim. A single trailing newline is dropped. The ``eos`` string is
    appended as a literal, so every accepted string ends with it.

    Args:
        line: text containing infill markers.
        eos: end-of-sequence string appended as a literal.

    Returns:
        The left-nested CConcat of all parts.
    """
    tokens = extract_concepts(line)
    # Drop one trailing newline; left in, it would become a literal that has to
    # be generated.
    if tokens and tokens[-1].endswith("\n"):
        tokens[-1] = tokens[-1][:-1]

    expansions = {
        "<|infill_ngram|>": ngram,
        "<|infill_word|>": word,
        "<|infill_sentence|>": sentence,
    }
    parts = [expansions[tok] if tok in expansions else literal(tok) for tok in tokens if tok != ""]
    parts.append(literal(eos))

    expr = None
    for part in parts:
        expr = part if expr is None else CConcat(expr, part)
    return expr

def convert_dfa(dfa, tokenizer):
    """Convert a FAdo DFA into the plain-dict layout consumed by custom_beam_search.

    Args:
        dfa: FAdo DFA. Its ``delta`` (state -> {symbol: target}) and ``Final``
            are read, and ``Initial`` when it is set.
        tokenizer: used to measure each state's shortest accepting string in
            tokens.

    Returns:
        dict with keys:

        - ``'states'``: sorted state ids (the keys of ``dfa.delta``).
        - ``'transition_matrix'``: ``[state][alphabet index] -> target state``.
        - ``'accepting_states'``: list of ``dfa.Final``.
        - ``'deadlock_states'``: non-accepting states whose transitions all
          loop back to themselves.
        - ``'alphabet'``: sorted string symbols seen on transitions.
        - ``'initial_state'``: ``dfa.Initial``, or 0 when it is not set.
        - ``'distances'``: token length of each state's min string (inf if none).
        - ``'min_strings'``: shortest string from each state to acceptance.
    """
    transition_dict = dfa.delta
    accepting_states = dfa.Final
    states = set(transition_dict.keys())

    # String symbols used on any transition; non-string symbols are ignored.
    alphabet = sorted({
        char
        for state_transitions in transition_dict.values()
        for char in state_transitions
        if isinstance(char, str)
    })
    char_to_index = {char: idx for idx, char in enumerate(alphabet)}

    # Take the start state from the DFA instead of assuming it is numbered 0.
    initial_state = getattr(dfa, "Initial", None)
    if initial_state is None:
        initial_state = 0

    deadlock_states = {
        state
        for state in states
        if state not in accepting_states
        and set(transition_dict[state].values()) == {state}
    }

    # NOTE(review): rows are indexed by state id, which assumes the states are
    # numbered 0..len(states)-1 — confirm. Transitions absent from delta default
    # to state 2, which only rejects if state 2 is a dead state in this DFA. For
    # a complete DFA every entry is overwritten below, so the default never matters.
    transition_matrix = [[2] * len(alphabet) for _ in range(len(states))]
    for state in states:
        for char, target_state in transition_dict[state].items():
            if char in char_to_index:
                transition_matrix[state][char_to_index[char]] = target_state

    states_list = sorted(states)

    min_strings = calculate_min_strings(transition_dict, accepting_states, states_list, alphabet, char_to_index)
    distances = calculate_token_distances_from_strings(min_strings, tokenizer)

    return {
        'states': states_list,
        'transition_matrix': transition_matrix,
        'accepting_states': list(accepting_states),
        'deadlock_states': list(deadlock_states),
        'alphabet': alphabet,
        'initial_state': initial_state,
        'distances': distances,
        'min_strings': min_strings
    }

def calculate_min_strings(transition_dict, accepting_states, states_list, alphabet, char_to_index):
    """Find, for every state, a shortest string that leads to an accepting state.

    The search runs backwards from the accepting states, using a heap ordered
    by ``(length, state, string)``, and expands predecessors one character at
    a time. When several symbols lead from a predecessor to the state being
    expanded, the placeholder "." is written instead of a concrete symbol.
    (NOTE(review): "." may also be a real alphabet symbol.) Only symbols
    present in ``char_to_index`` are followed; ``alphabet`` is not used.

    Returns:
        dict: state -> string; accepting states map to "" and states that
        cannot reach acceptance map to ``None``.
    """
    # Incoming edges: target -> [(source, symbol), ...].
    incoming = {s: [] for s in states_list}
    for src in states_list:
        for symbol, dst in transition_dict[src].items():
            if symbol in char_to_index:
                incoming[dst].append((src, symbol))

    best = {s: ("" if s in accepting_states else None) for s in states_list}

    frontier = [(0, s, "") for s in accepting_states]
    heapq.heapify(frontier)

    settled = set()
    while frontier:
        _, node, suffix = heapq.heappop(frontier)
        if node in settled:
            continue
        settled.add(node)
        best[node] = suffix

        # Group the symbols of still-unsettled predecessors by predecessor.
        grouped = {}
        for src, symbol in incoming[node]:
            if src not in settled:
                grouped.setdefault(src, []).append(symbol)

        for src, symbols in grouped.items():
            lead = symbols[0] if len(symbols) == 1 else "."
            extended = lead + suffix
            heapq.heappush(frontier, (len(extended), src, extended))

    return best

def calculate_token_distances_from_strings(min_strings, tokenizer):
    """Map each state to the number of tokens in its shortest accepting string.

    States whose string is ``None`` (acceptance unreachable) get ``inf``. The
    strings are encoded with ``add_special_tokens=False``.
    """
    def _token_count(text):
        if text is None:
            return float('inf')
        return len(tokenizer.encode(text, add_special_tokens=False))

    return {state: _token_count(text) for state, text in min_strings.items()}


import json
import os
from tqdm import tqdm
from itertools import product
import pickle
import base64

def process_json(input_json: str, output_dir: str, beam_sizes=None, alpha_values=None, gamma=1.0):
    """Run DFA-constrained infilling over cached stories for a grid of beam sizes and alphas.

    Stories (prompt plus pre-built DFA) are read from
    ``<output_dir>/dfa_cache_complete.json`` via load_dfa_from_cache and
    grouped by ``source_file``. For every ``(beams, alpha)`` pair and source
    file, results go to
    ``<output_dir>/<source>_beams_<beams>_alpha_<alpha>_gamma_<gamma>.json``,
    with dots removed from the numbers. Output files that already exist are
    skipped. Stories that raise are reported and skipped.

    Args:
        input_json: path, or list of paths, to JSON files with the input
            stories. They are read and tagged with ``source_file``, but
            generation only uses the stories found in the DFA cache.
        output_dir: directory holding the DFA cache and receiving the outputs;
            it is created if missing.
        beam_sizes: beam widths to sweep (default ``[32, 64, 128, 256]``).
        alpha_values: ``alpha_min`` values to sweep (default
            ``[0.1, 0.25, 0.5, 0.75]``).
        gamma: ramp exponent forwarded to custom_beam_search.

    Raises:
        ValueError: if ``input_json`` is neither a path nor a list of paths.
    """
    if beam_sizes is None:
        beam_sizes = [32, 64, 128, 256]
    if alpha_values is None:
        alpha_values = [0.1, 0.25, 0.5, 0.75]

    os.makedirs(output_dir, exist_ok=True)

    if isinstance(input_json, str):
        input_paths = [input_json]
    elif isinstance(input_json, list):
        input_paths = input_json
    else:
        raise ValueError("input_json must be path or path list")

    # NOTE(review): `data` is not used below, so reading the inputs only checks
    # that they exist and parse. Confirm whether the DFA cache was meant to be
    # built from them.
    data = []
    for file_path in input_paths:
        with open(file_path, "r", encoding="utf-8") as f:
            file_data = json.load(f)
        for item in file_data:
            item["source_file"] = os.path.basename(file_path)
        data.extend(file_data)

    dfa_output_file = os.path.join(output_dir, "dfa_cache_complete.json")
    stories_with_dfa = []
    if os.path.exists(dfa_output_file):
        print(f"Charging DFA from cache: {dfa_output_file}")
        try:
            stories_with_dfa = load_dfa_from_cache(dfa_output_file)
        except Exception as e:
            print(e)
            stories_with_dfa = []

    if not stories_with_dfa:
        # Without cached DFAs there is nothing to decode; say so instead of
        # silently iterating an empty grid.
        print(f"No cached DFAs found in {dfa_output_file}; nothing to generate.")
        return

    stories_by_source = {}
    for story in stories_with_dfa:
        stories_by_source.setdefault(story["source_file"], []).append(story)

    # Extra generation steps granted per infill marker, on top of the prompt length.
    infill_budget = {
        "<|infill_ngram|>": 9,
        "<|infill_word|>": 4,
        "<|infill_sentence|>": 16,
    }
    gamma_str = str(gamma).replace('.', '')

    for beams, alpha in product(beam_sizes, alpha_values):
        print(f"\nProcessing beams={beams}, alpha={alpha}, gamma={gamma}")
        alpha_str = str(alpha).replace('.', '')

        for source_file, source_stories in stories_by_source.items():
            base_name = source_file.replace('.json', '')
            output_file = os.path.join(output_dir, f"{base_name}_beams_{beams}_alpha_{alpha_str}_gamma_{gamma_str}.json")

            if os.path.exists(output_file):
                print(f"Existing file {output_file}, skipping...")
                continue

            results = {}

            for story_data in tqdm(source_stories, desc=f"{base_name} - Beams={beams}, alpha={alpha}"):
                i = story_data["id"]
                title = story_data["title"]
                prompt = story_data["prompt"]
                dfa = story_data["dfa"]

                # The prompt is encoded once: it sets the budget and is sliced off the output.
                prompt_ids = tokenizer.encode(prompt)
                max_len = len(prompt_ids) + sum(
                    infill_budget.get(c, 0) for c in extract_concepts(prompt)
                )

                try:
                    result = custom_beam_search(
                        model, tokenizer, prompt, dfa,
                        num_beams=beams, max_length=max_len,
                        alpha_min=alpha, gamma=gamma, eps=0.01, device=device
                    )
                    decoded = tokenizer.decode(
                        result[0][len(prompt_ids):],
                        skip_special_tokens=True
                    )
                except Exception as e:
                    # Best effort: report which story failed and move on.
                    print(f"{source_file} story {i}: {e}")
                    continue

                results[i] = {
                    "title": title,
                    "generated": decoded,
                    "params": {
                        "beams": beams,
                        "alpha": alpha,
                        "gamma": gamma
                    },
                    "source_file": source_file
                }

            if results:
                with open(output_file, "w", encoding="utf-8") as out:
                    json.dump(results, out, ensure_ascii=False, indent=2)
                
def load_dfa_from_cache(cache_file):
    """Read cached stories and their pickled DFAs from ``cache_file``.

    The cache is a JSON list of records carrying ``id``, ``title``,
    ``prompt``, ``dfa`` (a base64-encoded pickle) and ``max_length``, plus an
    optional ``source_file``. Each record is returned as a dict with the same
    keys, where ``dfa`` is unpickled and ``source_file`` defaults to "unknown".

    Returns ``[]`` when the file does not exist. It also returns ``[]`` when
    reading or decoding any record fails; the exception is printed and the
    whole cache is discarded.
    """
    if os.path.exists(cache_file):
        try:
            with open(cache_file, "r", encoding="utf-8") as handle:
                records = json.load(handle)

            def _decode(record):
                # SECURITY: pickle.loads can execute arbitrary code; the cache
                # must come from a trusted source.
                dfa_obj = pickle.loads(base64.b64decode(record["dfa"]))
                return {
                    "id": record["id"],
                    "title": record["title"],
                    "prompt": record["prompt"],
                    "dfa": dfa_obj,
                    "max_length": record["max_length"],
                    "source_file": record.get("source_file", "unknown"),
                }

            return [_decode(record) for record in records]
        except Exception as e:
            print(e)
            return []
    return []


if __name__ == "__main__":
    # Hyper-parameter sweep: every (beam size, alpha) pair is decoded with the
    # same gamma. input_json accepts a single path or a list of paths.
    sweep = dict(
        input_json=[],  # NOTE(review): empty — list the input story files here.
        output_dir="../data/text_infilling/",
        beam_sizes=[16, 32, 64],
        alpha_values=[0.0, 0.25, 0.5, 0.75],
        gamma=1.0,
    )
    process_json(**sweep)